import os
import torch
import transformers
from peft import (
    LoraConfig,
    get_peft_model,
    PeftModel,
    prepare_model_for_kbit_training,
    PeftConfig
    )
import json
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from argparse import ArgumentParser
import logging
import random


def get_logger(opt):
    """Return a logger that writes INFO records to the console and to ``opt.log``.

    Args:
        opt: Parsed CLI namespace. Only ``opt.log`` is read. It is used both as
            the logger's name and as the path of the log file to append to.

    Returns:
        logging.Logger: the named logger, with exactly one file handler for
        ``opt.log`` attached (repeated calls do not add duplicates).
    """
    log_format = '%(asctime)s %(levelname)-8s: %(message)s'
    # Console output goes through the root logger configured here (this is a
    # no-op if the root logger already has handlers).
    logging.basicConfig(format=log_format, level=logging.INFO)
    logger = logging.getLogger(opt.log)
    # Set the level on this logger explicitly so INFO records still reach the
    # file when the root logger was configured elsewhere at a higher level.
    logger.setLevel(logging.INFO)

    file_path = opt.log
    # FileHandler does not create missing directories; without this the default
    # './logger/train.log' fails on a fresh checkout.
    log_dir = os.path.dirname(file_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Avoid attaching a second handler for the same file on repeated calls,
    # which would write every record to the file twice.
    abs_path = os.path.abspath(file_path)
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == abs_path
        for handler in logger.handlers
    )
    if not already_attached:
        file_handler = logging.FileHandler(file_path)
        file_handler.setFormatter(logging.Formatter(log_format))
        file_handler.setLevel(logging.INFO)
        logger.addHandler(file_handler)
    return logger


def tokenize(prompt, add_eos_token=True, max_length=512):
    """Tokenize ``prompt`` with the module-level ``tokenizer`` for causal-LM training.

    Args:
        prompt: Text to encode. No padding is applied and no tensors are returned.
        add_eos_token: When True, append ``tokenizer.eos_token_id`` (and a
            matching attention-mask entry) unless the encoding already ends with it.
        max_length: Truncation length passed to the tokenizer. Because EOS is
            appended after truncation, the result can be ``max_length + 1`` long.

    Returns:
        The tokenizer's output mapping, with ``"labels"`` set to a copy of
        ``"input_ids"``.
    """
    result = tokenizer(
        prompt,
        truncation=True,
        padding=False,
        return_tensors=None,
        max_length=max_length,
    )
    input_ids = result["input_ids"]
    # `not input_ids` guards an empty encoding, which previously raised
    # IndexError on input_ids[-1].
    if add_eos_token and (not input_ids or input_ids[-1] != tokenizer.eos_token_id):
        input_ids.append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)
    result["labels"] = input_ids.copy()
    return result


def generate_user_prompt(data_point):
    """Render the user turn of the chat for one example.

    Reads the module-level ``opt`` and ``instructions`` globals.

    When ``opt.wo_instruction`` is falsy, the question and options are
    substituted into the ``instructions['direct']`` template. Examples with
    no options have every "\\nOptions: " substring removed from the rendered
    text. Otherwise a bare "Question: ..." prompt is built, with an
    "Options: ..." line only when options are present.
    """
    bare_prompt = opt.wo_instruction
    options = data_point['options']
    has_options = len(options) > 0
    question = data_point['question']

    if bare_prompt:
        if has_options:
            return f"Question: {question}\nOptions: {options}"
        return f"Question: {question}"

    template = instructions['direct']
    if has_options:
        return template.format(question, options)
    return template.format(question, "").replace("\nOptions: ", "")


def generate_and_tokenize_prompt(data_point):
    """Build one user/assistant chat example and tokenize it for training.

    Reads the module-level ``opt``, ``tokenizer`` and ``train_on_inputs`` globals.

    The assistant turn is "The answer is X." for data paths containing
    "auggpt", and "<explanation> Therefore, the answer is X." otherwise.
    When ``train_on_inputs`` is False, the label positions covered by the
    user prompt (plus generation prompt) are set to -100 so the loss covers
    only the assistant response.

    Returns:
        The mapping produced by ``tokenize`` for the full chat, with
        ``"labels"`` possibly masked as described above.
    """
    user = generate_user_prompt(data_point)
    if "auggpt" in opt.data:
        assistant = f"The answer is {data_point['answer']}."
    else:
        assistant = f"{data_point['explanation']} Therefore, the answer is {data_point['answer']}."
    chat = [{"role": "user", "content": user},
            {"role": "assistant", "content": assistant}]

    full_prompt = tokenizer.apply_chat_template(chat, add_generation_prompt=False, tokenize=False)
    tokenized_full_prompt = tokenize(full_prompt)

    if not train_on_inputs:
        user_chat = [{"role": "user", "content": user}]
        user_prompt = tokenizer.apply_chat_template(user_chat, add_generation_prompt=True, tokenize=False)
        tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
        user_prompt_len = len(tokenized_user_prompt["input_ids"])

        labels = tokenized_full_prompt["labels"]
        # Clamp the masked span to the label length. If the user-only prompt
        # tokenizes longer than the full chat (e.g. a template difference at
        # the assistant boundary), labels would otherwise end up longer than
        # input_ids.
        masked_len = min(user_prompt_len, len(labels))
        tokenized_full_prompt["labels"] = [-100] * masked_len + labels[masked_len:]
    return tokenized_full_prompt


def _str2bool(value):
    """Parse a CLI boolean such as "True"/"false"/"1"/"no".

    ``type=bool`` treats every non-empty string (including "False") as True.
    The empty string stays False, as with ``bool("")``.
    """
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    # argparse turns ValueError from a type converter into a usage error.
    raise ValueError(f"not a boolean: {value!r}")


def _str_list(value):
    """Split a comma-separated CLI value (e.g. "q_proj,v_proj") into a list of names.

    ``type=list`` would split the string into single characters.
    """
    return [item.strip() for item in value.split(",") if item.strip()]


parser = ArgumentParser()
parser.add_argument('--data', type=str, default='./data/ecare/full_dataset.jsonl')
parser.add_argument('--dev', type=str, default="./")
parser.add_argument('--base_model', type=str, default='/ssd1/huggingface_transformers/llama/hf_weights/7B')
parser.add_argument('--output_dir', type=str, default='./lora/7B')
parser.add_argument('--log', type=str, default='./logger/train.log')
parser.add_argument('--instruction', type=str, default='./data/instruction.json')
parser.add_argument('--template', type=str, default='./data/chat_template.json')
parser.add_argument('--batch_size', type=int, default=24)
parser.add_argument('--micro_batch_size', type=int, default=24)
# NOTE(review): --val_set_size is parsed but not referenced elsewhere in this file.
parser.add_argument('--val_set_size', type=int, default=0)
parser.add_argument('--epochs', type=int, default=3)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--use_lora', type=_str2bool, default=False,
                    help='continue training the adapter in --lora_dir (true/false)')
parser.add_argument('--lora_dir', type=str, default='./lora/gpt-4')
parser.add_argument('--lora_r', type=int, default=8)
parser.add_argument('--seed', type=int, default=1024)
parser.add_argument('--lora_alpha', type=int, default=16)
parser.add_argument('--lora_dropout', type=float, default=0.05)
parser.add_argument('--lora_target_modules', type=_str_list, default=["q_proj", "v_proj"],
                    help='comma-separated module names, e.g. q_proj,v_proj')
parser.add_argument('--group_by_length', type=_str2bool, default=False)
parser.add_argument('--wo_instruction', type=_str2bool, default=False)
parser.add_argument('--resume', type=_str2bool, default=False)
opt = parser.parse_args()

logger = get_logger(opt)
logger.info(opt)

world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
device_map = "auto"
# Steps to accumulate so that micro_batch_size * steps approximates batch_size.
gradient_accumulation_steps = opt.batch_size // opt.micro_batch_size

if ddp:
    # Under DDP each process keeps the whole model on its own local GPU.
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    # The global batch is shared by all processes, so each accumulates fewer steps.
    gradient_accumulation_steps = gradient_accumulation_steps // world_size

# The integer divisions above can reach 0 (e.g. batch_size == micro_batch_size
# with more than one process). The Trainer divides by this value, so keep it >= 1.
if gradient_accumulation_steps < 1:
    logger.warning(
        f"[accumulation_steps]: computed {gradient_accumulation_steps}, using 1 "
        f"(batch_size={opt.batch_size}, micro_batch_size={opt.micro_batch_size}, "
        f"world_size={world_size})"
    )
    gradient_accumulation_steps = 1

logger.info(f"[Device Map]: {device_map}")
logger.info(f"[World Size]: {world_size}")
logger.info(f"[accumulation_steps]: {gradient_accumulation_steps}")

logger.info('[Loading] Model & Tokenizer & Template')
# Context managers close the JSON files deterministically; bare open() left
# the handles to be closed by the garbage collector.
with open(opt.template, "r", encoding="utf-8") as template_file:
    # NOTE(review): chat_templates is not referenced elsewhere in this file.
    chat_templates = json.load(template_file)
with open(opt.instruction, "r", encoding="utf-8") as instruction_file:
    instructions = json.load(instruction_file)
tokenizer = AutoTokenizer.from_pretrained(opt.base_model, legacy=True)
if "Orca" in opt.base_model:
    # Explicit <|user|>/<|system|>/<|assistant|> chat template for Orca checkpoints.
    tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

# Padding reuses the EOS id.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'
logger.info("[Loading] Done!")
# When False, generate_and_tokenize_prompt masks the prompt tokens out of the loss.
train_on_inputs = False

# All base models are loaded in 8-bit, with the checkpoint's dtype, placed per device_map.
model_kwargs = dict(
    load_in_8bit=True,
    torch_dtype="auto",
    device_map=device_map,
)
if "gemma-2" in opt.base_model:
    # Gemma-2 is forced onto eager attention. Presumably this follows the
    # Hugging Face recommendation for training Gemma-2 — confirm if changing.
    model_kwargs["attn_implementation"] = 'eager'
model = AutoModelForCausalLM.from_pretrained(opt.base_model, **model_kwargs)

# Seed the RNGs before any LoRA weights are created. TrainingArguments(seed=...)
# is only applied when the Trainer is constructed, which happens after the
# adapter has already been initialised.
transformers.set_seed(opt.seed)
# Both branches need the quantised model prepared for k-bit training first.
model = prepare_model_for_kbit_training(model)

if opt.use_lora:
    # Continue from an existing adapter in opt.lora_dir and unfreeze its LoRA weights.
    model = PeftModel.from_pretrained(model, model_id=opt.lora_dir, device_map=device_map)
    for name, param in model.named_parameters():
        if 'lora' in name:
            param.requires_grad = True
else:
    # Fresh LoRA adapter configured from the CLI options.
    config = LoraConfig(
        r=opt.lora_r,
        lora_alpha=opt.lora_alpha,
        target_modules=opt.lora_target_modules,
        lora_dropout=opt.lora_dropout,
        bias='none',
        task_type="CAUSAL_LM"
    )
    logger.info(config)
    model = get_peft_model(model, config)


logger.info("[Loading] Data.")
# load_dataset("json", ...) puts everything under a single "train" split.
data = load_dataset("json", data_files=opt.data)

logger.info("[Loading] Training Data.")
data_eval = load_dataset("json", data_files=opt.dev)
train_data = data["train"].map(generate_and_tokenize_prompt)
# Seed the shuffle so the eval set order is reproducible across runs.
val_data = data_eval["train"].shuffle(seed=opt.seed).map(generate_and_tokenize_prompt)

model.print_trainable_parameters()

# Without DDP but with several GPUs visible, mark the model as model-parallel
# so the Trainer does not wrap it in DataParallel.
if not ddp and torch.cuda.device_count() > 1:
    model.is_parallelizable = True
    model.model_parallel = True

logger.info("[Training] Start")
os.environ["WANDB_DISABLED"] = "true"

# One evaluation and one checkpoint per epoch; fp16 mixed precision; no tqdm/reporting.
training_args = transformers.TrainingArguments(
    output_dir=opt.output_dir,
    seed=opt.seed,
    num_train_epochs=opt.epochs,
    per_device_train_batch_size=opt.micro_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=opt.lr,
    warmup_ratio=0.1,
    optim='adamw_torch',
    fp16=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=10,
    load_best_model_at_end=False,
    logging_steps=100,
    disable_tqdm=True,
    report_to="none",
    group_by_length=opt.group_by_length,
    ddp_find_unused_parameters=False if ddp else None,
)
# Pads input_ids/attention_mask/labels per batch (labels padded with -100 by the collator).
collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=2, return_tensors="pt", padding=True
)
trainer = transformers.Trainer(
    model=model,
    eval_dataset=val_data,
    train_dataset=train_data,
    args=training_args,
    data_collator=collator,
)

# Disable the KV cache while training.
model.config.use_cache = False

trainer.train()

# Write the final model to output_dir without safetensors.
model.save_pretrained(opt.output_dir, safe_serialization=False)



